# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Literal, Optional

from swift.llm.utils import update_generation_config_eos_token
from swift.plugin import extra_tuners
from swift.tuners import Swift
from swift.utils import get_logger
from ..utils import Messages

logger = get_logger()


@dataclass
class InferCliState:
    """Mutable state of one interactive command-line inference session.

    Keeps the conversation history, the multimodal inputs entered so far and
    the flags that decide how the next piece of user input is read and
    interpreted.
    """

    # None: use default-system. '': not use system.
    system: Optional[str] = None
    messages: Messages = field(default_factory=list)  # not including system

    # Paths or URLs entered for the `<image>` / `<audio>` / `<video>` tags.
    images: List[str] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)

    # True: input is read line by line until a line ends with `#`.
    multiline_mode: bool = False
    # True: the next input is stored as the system prompt instead of a query.
    input_system: bool = False

    def clear(self):
        """Drop the history and all collected media; keep system and modes."""
        self.messages = []
        self.images = []
        self.audios = []
        self.videos = []

    def add_query(self, query: str) -> None:
        """Append `query` to the history.

        A query starting with ``tool:`` is stored with role ``tool`` and the
        prefix removed; anything else is stored with role ``user``.
        """
        tool_prefix = "tool:"
        role = "user"
        if query.startswith(tool_prefix):
            role = "tool"
            query = query[len(tool_prefix) :]
        self.messages.append({"role": role, "content": query})

    def add_response(self, response: str) -> None:
        """Append `response` to the history with role ``assistant``."""
        self.messages.append({"role": "assistant", "content": response})

    def to_dict(self):
        """Return a deep-copied snapshot of the session as a plain dict.

        The returned ``messages`` list starts with a ``system`` message when
        `self.system` is not None. The state itself is left untouched.
        """
        # Work on a copy so inserting the system message never leaks into
        # `self.messages`, which is documented as "not including system".
        snapshot = deepcopy(self)
        if snapshot.system is not None:
            snapshot.messages.insert(0, {"role": "system", "content": snapshot.system})
        return {
            "messages": snapshot.messages,
            "images": snapshot.images,
            "audios": snapshot.audios,
            "videos": snapshot.videos,
        }

    def input_mm_data(self) -> None:
        """Ask for one path or URL per multimodal tag in the latest message.

        The content of the last message is scanned for ``<image>``,
        ``<video>`` and ``<audio>`` tags; for each occurrence, in order of
        appearance, the user is prompted and the answer is appended to the
        matching list (`images`, `videos` or `audios`). Does nothing when the
        history is empty.
        """

        def _input_mm_file(mm_type: Literal["image", "video", "audio"]) -> str:
            a_an = "an" if mm_type[0] in {"i", "a"} else "a"
            return input(f"Input {a_an} {mm_type} path or URL <<< ")

        if not self.messages:
            # Nothing to scan yet; avoids an IndexError on `messages[-1]`.
            return

        mm_types = ["image", "video", "audio"]
        query = self.messages[-1]["content"]
        mm_tags = re.findall("|".join(f"<{mm_type}>" for mm_type in mm_types), query)
        # mm_tag -> mm_type/mm_key
        mm_mapping = {f"<{mm_type}>": (mm_type, f"{mm_type}s") for mm_type in mm_types}
        for mm_tag in mm_tags:
            mm_type, mm_key = mm_mapping[mm_tag]
            getattr(self, mm_key).append(_input_mm_file(mm_type))

    @staticmethod
    def _input_multiline(prompt: str) -> str:
        """Read lines until one ends with ``#`` and return them joined.

        `prompt` is shown for the first line only. Every line keeps its
        trailing newline, except the terminating one, whose ``#`` and newline
        are both removed.
        """
        stop_words = "#\n"
        chunks = []
        while True:
            text = f"{input(prompt)}\n"
            prompt = ""  # only the first line carries the prompt
            if text.endswith(stop_words):
                chunks.append(text[: -len(stop_words)])
                break
            chunks.append(text)
        return "".join(chunks)

    def input_text(self) -> str:
        """Read the next user input, honouring the current input modes.

        The prompt carries ``[M]`` in multi-line mode and ``[S]`` while a
        system prompt is expected (``[MS]`` when both apply).
        """
        if self.multiline_mode:
            addi_prompt = "[MS]" if self.input_system else "[M]"
            text = InferCliState._input_multiline(f"<<<{addi_prompt} ")
        else:
            addi_prompt = "[S]" if self.input_system else ""
            text = input(f"<<<{addi_prompt} ")
        return text

    def check_query(self, query: str) -> Optional[str]:
        """Handle CLI commands and return `query` only if it is a real query.

        Commands are matched after stripping whitespace and lower-casing:
        ``clear``, ``reset-system``, ``multi-line``, ``single-line`` and the
        empty string. While `input_system` is set, the input is consumed as
        the new system prompt (``default-system`` restores the default) and
        the history is cleared.

        Returns:
            The unchanged `query` when it should be sent to the model,
            otherwise None.
        """
        query_std = query.strip().lower()
        if self.input_system:
            # Compare the normalised text, like every other command: in
            # multi-line mode the raw input may end with a newline, which
            # would otherwise be stored literally as the system prompt.
            if query_std == "default-system":
                self.system = None
            else:
                self.system = query
            self.input_system = False
            # A new system prompt always starts a fresh conversation.
            query_std = "clear"
        if query_std == "clear":
            self.clear()
            return None
        if query_std == "":
            return None
        if query_std == "reset-system":
            self.input_system = True
            return None
        if query_std == "multi-line":
            self.multiline_mode = True
            logger.info("End multi-line input with `#`.")
            logger.info("Input `single-line` to switch to single-line input mode.")
            return None
        if query_std == "single-line":
            self.multiline_mode = False
            return None
        return query


def prepare_adapter(args, model, adapters=None):
    """Attach the configured adapters to `model` and return the result.

    With the ``unsloth`` tuner backend the model is only switched to
    inference mode through Unsloth's ``for_inference`` and returned as is.
    Otherwise every adapter is loaded in order via the tuner's
    ``from_pretrained``; the tuner is looked up in `extra_tuners` by
    `args.train_type`, with `Swift` as the fallback.

    Args:
        args: Arguments object providing `tuner_backend`, `train_type`,
            `adapters` and, for unsloth, `model_meta.is_multimodal`.
        model: The base model to wrap.
        adapters: Adapters to load; `args.adapters` is used when this is
            None or empty.
    """
    if args.tuner_backend == "unsloth":
        # Imported lazily so unsloth is only required when it is selected.
        if not args.model_meta.is_multimodal:
            from unsloth import FastLanguageModel as UnslothModel
        else:
            from unsloth import FastVisionModel as UnslothModel
        UnslothModel.for_inference(model)
        return model

    train_type = args.train_type
    tuner = extra_tuners[train_type] if train_type in extra_tuners else Swift

    # compat deploy
    adapter_list = adapters if adapters else args.adapters
    for adapter in adapter_list:
        model = tuner.from_pretrained(model, adapter)

    if train_type == "bone":
        # Bone has a problem of float32 matmul with bfloat16 in `peft==0.14.0`
        model.to(model.dtype)
    return model


def prepare_model_template(args, **kwargs):
    """Build the inference-ready model and its template from `args`.

    The model and processor come from `args.get_model_processor(**kwargs)`,
    adapters are attached with `prepare_adapter`, and the template is created
    from the processor. The model's generation config is then updated from
    the template via `update_generation_config_eos_token`.

    Returns:
        A ``(model, template)`` tuple.
    """
    base_model, processor = args.get_model_processor(**kwargs)
    model = prepare_adapter(args, base_model)
    template = args.get_template(processor)
    update_generation_config_eos_token(model.generation_config, template)
    return model, template
